import torch
from .conmh import  conmh_encoder, conmh_decoder

def get_model(cfg):
    """Build the full model selected by ``cfg.model_name``.

    Args:
        cfg: Configuration object. Only its ``model_name`` attribute is read
            here; the whole object is forwarded to the model constructor.

    Returns:
        The object constructed by ``conmh(cfg)``.

    Raises:
        ValueError: If ``cfg.model_name`` is not a recognized model name.
    """
    if cfg.model_name == 'conmh':
        # ``conmh`` is not among the names imported at the top of this module,
        # so a bare reference raised NameError. Import it lazily from the same
        # sibling module that provides the encoder/decoder.
        # NOTE(review): confirm ``.conmh`` still defines ``conmh``.
        from .conmh import conmh
        return conmh(cfg)

    # Without this, an unrecognized name fell through to ``return model`` and
    # surfaced as an UnboundLocalError instead of a meaningful message.
    raise ValueError(f"Unknown model_name for get_model: {cfg.model_name!r}")

def get_encoder(cfg):
    """Build the encoder selected by ``cfg.model_name``.

    Args:
        cfg: Configuration object. Only its ``model_name`` attribute is read
            here; the whole object is forwarded to the encoder constructor.

    Returns:
        The object constructed by ``conmh_encoder(cfg)``.

    Raises:
        ValueError: If ``cfg.model_name`` is not a recognized encoder name.
    """
    # NOTE(review): the encoder is keyed on '20241115' while get_model keys
    # on 'conmh'; confirm this naming split is intentional.
    if cfg.model_name == '20241115':
        return conmh_encoder(cfg)

    # Previously an unrecognized name left ``model`` unbound and raised
    # UnboundLocalError; report the offending value explicitly instead.
    raise ValueError(f"Unknown model_name for get_encoder: {cfg.model_name!r}")


def get_decoder(cfg):
    """Build the decoder selected by ``cfg.model_name``.

    Args:
        cfg: Configuration object. Only its ``model_name`` attribute is read
            here; the whole object is forwarded to the decoder constructor.

    Returns:
        The object constructed by ``conmh_decoder(cfg)``.

    Raises:
        ValueError: If ``cfg.model_name`` is not a recognized decoder name.
    """
    # NOTE(review): the decoder is keyed on '20241115' while get_model keys
    # on 'conmh'; confirm this naming split is intentional.
    if cfg.model_name == '20241115':
        return conmh_decoder(cfg)

    # Previously an unrecognized name left ``model`` unbound and raised
    # UnboundLocalError; report the offending value explicitly instead.
    raise ValueError(f"Unknown model_name for get_decoder: {cfg.model_name!r}")
